/*
 * cascade.h
 */

#ifndef EXAMPLES_CPP_CASCADE_H_
#define EXAMPLES_CPP_CASCADE_H_

#include <cassert>
#include <map>
#include <memory>
#include <tuple>
#include <vector>

#include "dynet/dynet.h"
#include "dynet/lstm.h"
#include "dynet/training.h"

// NOTE(review): using-directives in a header leak into every includer;
// prefer qualifying std::/dynet:: names when this header is next reworked.
using namespace std;
using namespace dynet;

class multitask {
public:
  /**
   * \brief Initializes vocabulary, builders and parameters
   *
   * \param model ParameterCollection holding the parameters
   */
  virtual void initialize(ParameterCollection& model) = 0;
  virtual void initialize_extra(ParameterCollection& model) = 0;
  virtual void initialize_partial(ParameterCollection& model) = 0;

  /**
   * \brief computes loss for the network for a triple of (input_sentence, int_sentence, output_sentence)
   *
   * \param enc_fwd_lstm forward lstm
   * \param enc_bwd_lstm backward lstm
   * \param enc_compress_lstm1 first forward compressing lstm 
   * \param enc_compress_lstm2 second forward compressing lstm 
   * \param dec_1_lstm Decoder lstm for the first task
   * \param dec_2_lstm Decoder lstm for the second task
   */
  virtual Expression get_loss(const vector<vector<float>>&  input_sentence, const vector<int>&  int_sentence, const vector<int>&  output_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, 
      LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg) = 0;
  virtual Expression get_loss(const vector<vector<vector<float>>>&  input_sentence, const vector<vector<int>>&  int_sentence, const vector<vector<int>>&  output_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, 
      LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg) = 0;


    /**
   * \brief executes code to train the network
   *
   * \param model ParameterCollection holding the parameters
   * \param in_seq Input feature sequences
   * \param int_seq the first target sequence
   * \param out_seq the second target sequence
   * \param trainer Trainer instance
   */
  //virtual float train(ParameterCollection& model, const vector<int>&  in_seq, const vector<int>&  int_seq,  const vector<int>&  out_seq, AdamTrainer& trainer, dynet::real l_scale);
  virtual float train(ParameterCollection& model, const vector<vector<vector<float>>>&  in_seq, const vector<vector<int>>&  int_seq,  const vector<vector<int>>&  out_seq, AdamTrainer& trainer, dynet::real l_scale) = 0;
  virtual float test_dev(ParameterCollection& model, const vector<vector<float>>&  in_seq, const vector<int>&  int_seq, const vector<int>&  out_seq) = 0;
  //virtual float test_dev_bleu(ParameterCollection& model, const vector<int>&  in_seq, const vector<int>&  int_seq, const vector<int>&  out_seq) = 0;

  virtual void test(ParameterCollection& model, const vector<vector<float>>&  in_seq, int beamsize) = 0;

  /**
    * \brief runs the two decoders lstm with attention over the input encoded sequence and computes loss
    *
    * \param dec_1_lstm First Decoder lstm
    * \param dec_2_lstm Second Decoder lstm
    * \param cg Computation graph
    */
  virtual Expression decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg) = 0;
  virtual Expression decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<vector<int>>& int_batch, const vector<vector<int>>& trg_batch, ComputationGraph& cg) = 0;

  /**
    * \brief Prints the attention matrices during forced decoding
    */
  virtual void dump_attentions(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>&  int_seq, const vector<int>&  out_seq) = 0;
  /**
    * \brief Computes attention values using input and the lstm states of input
    *
    * \param input_mat encoded input
    * \param state decoder lstm
    * \param w1dt input weighted by w1
    * \param cg Computation graph
    */
  virtual Expression attend_1(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) = 0;
  
  
  /**
   * \brief encodes input sentence using bidirectional lstm
   *
   * \param enc_fwd_lstm forward lstm
   * \param enc_bwd_lstm backward lstm
   * \param embedded input embeddings
   */
  vector<Expression> encode_sentence(LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, vector<Expression>& embedded);
  /**
   * \brief encodes input feature sequence using a 3-layer pyramidal encoder
   *
   * \param enc_fwd_lstm forward lstm
   * \param enc_bwd_lstm backward lstm
   * \param enc_compress_lstm1 first forward compressing lstm 
   * \param enc_compress_lstm2 second forward compressing lstm 
   */
  vector<Expression> encode_features(LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, vector<Expression>& embedded);

  /**
   * \brief runs lstm over the input embeddings
   *
   * \param init_state lstm instance
   * \param input_vecs input character embeddings
   */
  vector<Expression> run_lstm(LSTMBuilder& init_state, const vector<Expression>& input_vecs);

  /**
   * \brief constructs input sentence embeddings, or reads input feature sequences
   *
   * \param sentence input sentence
   * \param cg computation graph instance
   */
  vector<Expression> embed_sentence(const vector<int>&  sentence, ComputationGraph& cg);
  vector<Expression> embed_sentence(const vector<vector<int>>&  batch, ComputationGraph& cg);
  vector<Expression> embed_features(const vector<vector<float>>&  sentence, ComputationGraph& cg);
  vector<Expression> embed_features(const vector<vector<vector<float>>>&  batch, ComputationGraph& cg);
  vector<Expression> embed_stack_features(const vector<vector<float>>&  sentence, ComputationGraph& cg);
  vector<Expression> embed_stack_features(const vector<vector<vector<float>>>&  batch, ComputationGraph& cg);

  // Used for evauation
  float bleu(const vector<int>& hyp, const vector<int>& ref);
  map<vector<int>,int> get_ngrams(const vector<int>& sentence);


//private:
  static const int REP_SIZE = 512;
  Parameter decoder_1_w;
  Parameter decoder_1_b;

  Parameter decoder_2_w;
  Parameter decoder_2_b;

  Parameter attention_1_w1;
  Parameter attention_1_w2;
  Parameter attention_1_v;

  Parameter attention_2_w1;
  Parameter attention_2_w2;
  Parameter attention_2_v;

  Parameter attention_3_w1;
  Parameter attention_3_w2;
  Parameter attention_3_v;

  LookupParameter output_1_lookup;
  LookupParameter output_2_lookup;
};

// Implements a cascade model, similar to a reconstruction one
class cascade : public multitask
{
public:
  void initialize(ParameterCollection& model) override;
  void initialize_extra(ParameterCollection& model) override;
  void initialize_partial(ParameterCollection& model) override;

  // See multitask::get_loss (second overload is the batched variant).
  Expression get_loss(const vector<vector<float>>& input_sentence, const vector<int>& int_sentence, const vector<int>& output_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm,
      LSTMBuilder& dec_2_lstm, ComputationGraph& cg) override;
  Expression get_loss(const vector<vector<vector<float>>>& input_sentence, const vector<vector<int>>& int_sentence, const vector<vector<int>>& output_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm,
      LSTMBuilder& dec_2_lstm, ComputationGraph& cg) override;

  //float train(ParameterCollection& model, const vector<int>&  in_seq, const vector<int>&  int_seq,  const vector<int>&  out_seq, AdamTrainer& trainer, dynet::real l_scale);
  float train(ParameterCollection& model, const vector<vector<vector<float>>>& in_seq, const vector<vector<int>>& int_seq, const vector<vector<int>>& out_seq, AdamTrainer& trainer, dynet::real l_scale) override;
  float test_dev(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq, const vector<int>& out_seq) override;
  // Not part of the multitask interface (base declaration is commented out).
  float test_dev_bleu(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq, const vector<int>& out_seq);
  void test(ParameterCollection& model, const vector<vector<float>>& in_seq, int beamsize) override;

  // Decodes one input; returns one token sequence per task (presumably the
  // best hypothesis of each decoder -- confirm in the implementation).
  tuple<vector<int>, vector<int>> generate_nbest(const vector<vector<float>>& in_seq, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg, int nbest_1_size, int nbest_2_size, int beamsize);

  Expression decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg) override;
  Expression decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<vector<int>>& int_batch, const vector<vector<int>>& trg_batch, ComputationGraph& cg) override;

  // Regularization term computed from two attention weight sequences.
  Expression reg_joint(vector<Expression>& att_21, vector<Expression>& att_32, ComputationGraph& cg);

  Expression attend_1(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) override;
  Expression attend_3(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg);

  void dump_attentions(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq, const vector<int>& out_seq) override;
  // Forced-decoding pass that records/prints the attention matrices.
  void decode_attentions(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg);
};

// Implements a triangle model
class triangle : public multitask
{
public:
  void initialize(ParameterCollection& model) override;
  void initialize_extra(ParameterCollection& model) override;
  void initialize_partial(ParameterCollection& model) override;

  // See multitask::get_loss (second overload is the batched variant).
  Expression get_loss(const vector<vector<float>>& input_sentence, const vector<int>& int_sentence, const vector<int>& output_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm,
      LSTMBuilder& dec_2_lstm, ComputationGraph& cg) override;
  Expression get_loss(const vector<vector<vector<float>>>& input_sentence, const vector<vector<int>>& int_sentence, const vector<vector<int>>& output_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm,
      LSTMBuilder& dec_2_lstm, ComputationGraph& cg) override;
  float train(ParameterCollection& model, const vector<vector<vector<float>>>& in_seq, const vector<vector<int>>& int_seq, const vector<vector<int>>& out_seq, AdamTrainer& trainer, dynet::real l_scale) override;
  float test_dev(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq, const vector<int>& out_seq) override;
  // Not part of the multitask interface (base declaration is commented out).
  float test_dev_bleu(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq, const vector<int>& out_seq);
  void test(ParameterCollection& model, const vector<vector<float>>& in_seq, int beamsize) override;

  // Decodes one input; returns one token sequence per task. The _asr variant
  // presumably targets the ASR task specifically -- confirm in the
  // implementation.
  tuple<vector<int>, vector<int>> generate_nbest(const vector<vector<float>>& in_seq, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg, int nbest_1_size, int nbest_2_size, int beamsize);
  tuple<vector<int>, vector<int>> generate_nbest_asr(const vector<vector<float>>& in_seq, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg, int nbest_1_size, int nbest_2_size, int beamsize);

  Expression decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg) override;
  Expression decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<vector<int>>& int_batch, const vector<vector<int>>& trg_batch, ComputationGraph& cg) override;

  // Regularization term computed from three attention weight sequences.
  Expression reg_joint(vector<Expression>& att_21, vector<Expression>& att_32, vector<Expression>& att_31);

  Expression attend_1(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) override;
  Expression attend_2(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg);
  Expression attend_3(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg);

  void dump_attentions(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq, const vector<int>& out_seq) override;
  // Forced-decoding pass that records/prints the attention matrices.
  void decode_attentions(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg);
};

// Implement a simple multitask scenario
class simplemultitask : public multitask
{
public:
  void initialize(ParameterCollection& model);
  void initialize_extra(ParameterCollection& model);
  void initialize_partial(ParameterCollection& model);

  Expression get_loss(const vector<vector<float>>&  input_sentence, const vector<int>&  int_sentence, const vector<int>&  output_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm,
      LSTMBuilder& dec_2_lstm, ComputationGraph& cg);
  Expression get_loss(const vector<vector<vector<float>>>&  input_sentence, const vector<vector<int>>&  int_sentence, const vector<vector<int>>&  output_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm,
      LSTMBuilder& dec_2_lstm, ComputationGraph& cg);
  //float train(ParameterCollection& model, const vector<int>&  in_seq, const vector<int>&  int_seq,  const vector<int>&  out_seq, AdamTrainer& trainer, dynet::real l_scale);
  float train(ParameterCollection& model, const vector<vector<vector<float>>>&  in_seq, const vector<vector<int>>&  int_seq,  const vector<vector<int>>&  out_seq, AdamTrainer& trainer, dynet::real l_scale);
  float test_dev(ParameterCollection& model, const vector<vector<float>>&  in_seq, const vector<int>&  int_seq, const vector<int>&  out_seq);
  float test_dev_bleu(ParameterCollection& model, const vector<vector<float>>&  in_seq, const vector<int>&  int_seq, const vector<int>&  out_seq);
  void test(ParameterCollection& model, const vector<vector<float>>&  in_seq, int beamsize);
  tuple<vector<int>, vector<int>> generate_nbest(const vector<vector<float>>&  in_seq, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg, int nbest_1_size, int nbest_2_size, int beamsize);
  
  Expression decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg);
  Expression decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<vector<int>>& int_batch, const vector<vector<int>>& trg_batch, ComputationGraph& cg);
  
  Expression attend_1(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg);
  Expression attend_2(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg);

  void dump_attentions(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>&  int_seq, const vector<int>&  out_seq);
  void decode_attentions(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg);

};

// Implement a simple unitask scenario (only one decoder used)
class simpleunitask : public multitask
{
public:
  void initialize(ParameterCollection& model) override;
  // The extra/partial initialization schemes coincide with the plain one for
  // the single-task model.
  void initialize_extra(ParameterCollection& model) override {
    initialize(model);
  }
  void initialize_partial(ParameterCollection& model) override {
    initialize(model);
  }

  // Two-decoder interface inherited from multitask: the second target
  // sequence and the second decoder are ignored and the call forwards to the
  // single-decoder overloads declared below.
  Expression get_loss(const vector<vector<float>>& input_sentence, const vector<int>& int_sentence, const vector<int>& output_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm,
      LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg) override {
    return get_loss(input_sentence, int_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm1, enc_compress_lstm2, dec_1_lstm, cg);
  }
  Expression get_loss(const vector<vector<vector<float>>>& input_sentence, const vector<vector<int>>& int_sentence, const vector<vector<int>>& output_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm,
      LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, ComputationGraph& cg) override {
    return get_loss(input_sentence, int_sentence, enc_fwd_lstm, enc_bwd_lstm, enc_compress_lstm1, enc_compress_lstm2, dec_1_lstm, cg);
  }

  // Single-decoder loss (per-sentence and batched variants).
  Expression get_loss(const vector<vector<float>>& input_sentence, const vector<int>& int_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm,
      LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, ComputationGraph& cg);
  Expression get_loss(const vector<vector<vector<float>>>& input_sentence, const vector<vector<int>>& int_sentence, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm,
      LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, ComputationGraph& cg);

  // Inherited trainer entry point; out_seq is unused for the unitask model.
  float train(ParameterCollection& model, const vector<vector<vector<float>>>& in_seq, const vector<vector<int>>& int_seq, const vector<vector<int>>& out_seq, AdamTrainer& trainer, dynet::real l_scale) override {
    return train(model, in_seq, int_seq, trainer, l_scale);
  }
  float train(ParameterCollection& model, const vector<vector<vector<float>>>& in_seq, const vector<vector<int>>& int_seq, AdamTrainer& trainer, dynet::real l_scale);

  // Inherited dev evaluation; out_seq is unused for the unitask model.
  float test_dev(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq, const vector<int>& out_seq) override {
    return test_dev(model, in_seq, int_seq);
  }
  float test_dev(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq);

  // Not part of the multitask interface (base declaration is commented out).
  float test_dev_bleu(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq, const vector<int>& out_seq) {
    return test_dev_bleu(model, in_seq, int_seq);
  }
  float test_dev_bleu(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq);

  void test(ParameterCollection& model, const vector<vector<float>>& in_seq, int beamsize) override;

  // Single-task decoding: returns one token sequence.
  vector<int> generate_nbest(const vector<vector<float>>& in_seq, LSTMBuilder& enc_fwd_lstm, LSTMBuilder& enc_bwd_lstm, LSTMBuilder& enc_compress_lstm1, LSTMBuilder& enc_compress_lstm2, LSTMBuilder& dec_1_lstm, ComputationGraph& cg, int nbest_1_size, int beamsize);

  // Two-decoder decode interface: forwards to the single-decoder overloads,
  // dropping the unused second decoder and second target sequence.
  Expression decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, const vector<int>& trg_sentence, ComputationGraph& cg) override {
    return decode(dec_1_lstm, encoded, int_sentence, cg);
  }
  Expression decode(LSTMBuilder& dec_1_lstm, LSTMBuilder& dec_2_lstm, vector<Expression>& encoded, const vector<vector<int>>& int_batch, const vector<vector<int>>& trg_batch, ComputationGraph& cg) override {
    return decode(dec_1_lstm, encoded, int_batch, cg);
  }

  Expression decode(LSTMBuilder& dec_1_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, ComputationGraph& cg);
  Expression decode(LSTMBuilder& dec_1_lstm, vector<Expression>& encoded, const vector<vector<int>>& int_batch, ComputationGraph& cg);

  Expression attend_1(LSTMBuilder& state, Expression w1dt, ComputationGraph& cg) override;

  // Inherited attention dump; out_seq is unused for the unitask model.
  void dump_attentions(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq, const vector<int>& out_seq) override {
    dump_attentions(model, in_seq, int_seq);
  }
  void dump_attentions(ParameterCollection& model, const vector<vector<float>>& in_seq, const vector<int>& int_seq);
  void decode_attentions(LSTMBuilder& dec_1_lstm, vector<Expression>& encoded, const vector<int>& int_sentence, ComputationGraph& cg);
};



// Used during beam search
// A partial hypothesis kept on the beam during beam-search decoding: the
// tokens generated so far, the running score (presumably accumulated
// log-probability -- confirm against the search code), and the decoder LSTM
// states needed to extend the hypothesis by one more token.
class DecoderHyp {
public:
    DecoderHyp(float score, const vector<Expression>& states, const vector<int> & sent)
        : score_(score),
          states_(states),
          sent_(sent) {}

    // Running score of this partial hypothesis.
    float GetScore() const { return score_; }
    // Decoder LSTM state to resume decoding from.
    const vector<Expression>& GetStates() const { return states_; }
    // Output token ids produced so far.
    const vector<int>& GetSentence() const { return sent_; }

protected:
    float score_;                 // hypothesis score
    vector<Expression> states_;   // decoder state snapshot
    vector<int> sent_;            // partial output sequence
};

typedef std::shared_ptr<DecoderHyp> DecoderHypPtr;

// Orders hypothesis pointers so that the HIGHER-scoring hypothesis compares
// as "less" (i.e. sorts first); ties are broken deterministically by the
// token sequence. Both pointers must be non-null.
inline bool operator<(const DecoderHypPtr & lhs, const DecoderHypPtr & rhs) {
  assert(lhs != nullptr);
  assert(rhs != nullptr);
  const float lhs_score = lhs->GetScore();
  const float rhs_score = rhs->GetScore();
  if (lhs_score != rhs_score)
    return lhs_score > rhs_score;          // best score first
  return lhs->GetSentence() < rhs->GetSentence();  // lexicographic tie-break
}


#endif /* EXAMPLES_CPP_CASCADE_H_ */
